feat(body_axis): add skeleton-agnostic AP inference #945
khan-u wants to merge 21 commits into neuroinformatics-unit:main
Self-axis grid search workflow for tuning default thresholds of 3-step filter cascade
Assumes only that keypoints exhibit lateral alignment about the body axis.
Datasets: https://dreem.sleap.ai/0.5.1/datasets
Self-Axis Workflow (redesigned as shared-axis in next comment)
Handling uncheckable pairs
Result:
The self-axis grid search architecture optimizes thresholds for an evaluation that doesn't test generalization
A separate validation script applied the grid-search-winning thresholds to all 41 individuals. The self-axis evaluation (Pass 3) confirmed 100% accuracy for best individuals (SANITY CHECK):
However, when projecting each individual onto their dataset's reference axis (derived from that dataset's best individual), accuracy dropped substantially:
ap_validation_20260407_002942.log
Grid-Search for Shared-Axis Evaluations
Motivation
The filter cascade thresholds must satisfy a robustness constraint:
This constraint is non-trivial because individuals within a dataset exhibit:
Effect on Shared-Axis Projections
Self-axis evaluation (demonstrated above) sidesteps this variability entirely:
Shared-axis evaluation (current goal) confronts this variability directly:
Shared-Axis Workflow
Handling uncheckable pairs
TO DO:



Description
Skeleton-Agnostic Body-Axis Inference Pipeline
This PR draft depends on #875 (also a draft). For review, the work here mostly covers skeleton-agnostic body-axis inference; its utility for computing polarization is described here:
→ compute_polarization – Revised API: #875 (comment)
For the latest updates on this draft PR, see the comments below.
→ body_axis – Revised API
What is this PR
Why is this PR needed?
This PR addresses a practical problem in computing orientation polarization from body-axis keypoints: the user must specify which keypoint pair defines the posterior → anterior axis, i.e., the `from_node` and `to_node`.
The AP validation pipeline introduced here automatically assesses that choice and, when the input pair is not well supported, suggests a better alternative.
What does this PR do?
The core question it answers is:
The pipeline is implemented in a new module, `movement/kinematics/body_axis.py`, which provides:
- `ValidateAPConfig`: configuration dataclass for all tunable parameters
- `FrameSelection`: dataclass bundling frame indices and segment assignments
- `APNodePairReport`: dataclass with detailed AP pair-evaluation results
- `validate_ap()`: main validation function for a single individual
- `run_ap_validation()`: multi-individual validation entry point

The validation is called by `compute_polarization()` as a side-channel diagnostic when `validate_ap=True` and `body_axis_keypoints` is provided.
The validation results do not affect the polarization computation itself, but are stored in `polarization.attrs["ap_validation_result"]` for the user to inspect.
Configuration parameters for the various thresholds can be supplied by the user via `ap_validation_config`.

The pipeline in `validate_ap()` works through these stages:
1. Tiered validity
- … `min_valid_frac` of keypoints present and ≥ 2 total
2. Bounding-box centroid computation
3. High-motion segment detection
- … `window_len` speed samples, advanced by `stride` samples, compute median speeds:
- … `pct_thresh` percentile of all valid-window medians
- … `min_run_len`)
4. Tier-2 filtering on segments
5. Centroid-centered skeleton construction
6. Postural clustering
- … `postural_var_ratio_thresh` and at least 6 frames are available:
7. PCA on the average skeleton
- `PC1[1] >= 0` (y-component non-negative)
- `PC2[0] >= 0` (x-component non-negative)
8. Anterior direction inference via velocity voting
- … `compute_polarization()`:
- … `confidence_floor`:
9. Input AP Node-Pair Filter Cascade
Given a candidate keypoint pair (e.g., tail_base → nose), it evaluates quality through:
Step I - Lateral alignment filter
- … `lateral_thresh_pct` (default: 50th percentile) are eliminated
- … `max == min`), all normalized offsets are set to 0 and all nodes pass
Step II - Opposite-sides constraint
Step III - Distal/proximal classification
- … `edge_thresh_pct` (default: 70th percentile)
Loss diagnostics
10. Suggested Pair
- `lateral_std` is the normalized standard deviation of each node's lateral offset over time
- … `order_pair_by_ab()` so that:
  - … `body_axis_keypoints=(from_node, to_node)` convention
- … `max_separation_distal_nodes` or `max_separation_nodes` on the `APNodePairReport`
- … `input_pair_order_matches_inference`) compares the input pair's AP coordinates:
  - `True` if `from_node`'s AP coordinate < `to_node`'s AP coordinate (i.e., `from_node` is more posterior)
11. Mutually Exclusive Scenarios
- … `accept`/`warn`) based on whether the input pair survived all filters, is distal, has maximum separation, etc.
Return: xarray Attribute `ap_validation_result`
When `validate_ap=True` and `body_axis_keypoints` is provided, `compute_polarization()` stores results in `polarization.attrs["ap_validation_result"]`:

    {
        "all_results": [<per-individual result dicts>],
        "best_idx": int  # index into all_results (highest R×M score)
    }

Per-Individual Result Dict Fields
- `success`
- `anterior_sign`
- `vote_margin`
- `resultant_length`
- `circ_mean_dir`
- `num_selected_frames`
- `num_clusters`
- `primary_cluster`
- `PC1`: shape `(2,)`
- `PC2`: shape `(2,)`
- `avg_skeleton`: shape `(n_keypoints, 2)`
- `vel_projs_pc1`
- `lateral_std`
- `longitudinal_std`
- `pair_report`: `APNodePairReport` with detailed AP pair-evaluation
- `log_lines`: … `compute_polarization()`, which hardcodes `verbose=False`)
- `error_msg`
- `individual`: … `run_ap_validation()`)

The `pair_report` field contains `scenario` (1-13) and `outcome` ("accept"/"warn") from the flowchart above.
The `collective.py` API was recently redesigned.
→ Planned API revision – Newer Comment
Usage (will be deprecated)
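A minimal sketch of how the stored result could be inspected, assuming only the dict layout described above (`all_results`, `best_idx`, and the per-individual fields `success`, `resultant_length`, `vote_margin`, `error_msg`). The helper names and the mock values are hypothetical and are not part of the PR; in practice the dict would come from `polarization.attrs["ap_validation_result"]`.

```python
from typing import Any


def best_individual_result(validation: dict[str, Any]) -> dict[str, Any]:
    """Return the per-individual result dict that ``best_idx`` points to."""
    return validation["all_results"][validation["best_idx"]]


def summarize(validation: dict[str, Any]) -> str:
    """One-line summary of the best individual's validation outcome."""
    best = best_individual_result(validation)
    if not best["success"]:
        return f"validation failed: {best['error_msg']}"
    return (
        f"R={best['resultant_length']:.2f}, "
        f"vote margin={best['vote_margin']:.2f}"
    )


# Mock of the documented layout (values are made up for illustration).
mock_validation = {
    "all_results": [
        {"success": False, "error_msg": "too few frames"},
        {
            "success": True,
            "resultant_length": 0.9,
            "vote_margin": 0.8,
            "error_msg": "",
        },
    ],
    "best_idx": 1,
}

print(summarize(mock_validation))
```

Because validation is a side-channel diagnostic, a helper like this only reads the attribute and never touches the polarization values themselves.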
How has this PR been tested?
1. Using `test_body_axis.py`.
Comprehensive testing pending implementation of the planned refactor
`TestValidateAPConfig` (2 tests)
Parameter-boundary validation for the `ValidateAPConfig` dataclass. Tests all 12 configurable fields:
- `test_invalid_config_values_raise` (23 parametrized cases)
  Each field is tested with out-of-range values:
  - negative fractions
  - values above 1.0 for [0, 1] fields
  - zero or negative integers for count fields
  - floats where integers are required
  All must raise `ValueError` with a message matching `"must be"`.
- `test_valid_config_does_not_raise`
  Constructs a `ValidateAPConfig` with all fields set to non-default valid values and asserts that no exception is raised.
The 12 fields tested are all listed in the config table below.
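The boundary checks described above can be illustrated with a toy stand-in. This is a sketch, not the PR's `ValidateAPConfig`: the class `ToyAPConfig`, its two fields' defaults, and the exact error messages are assumptions chosen only to show the "raise `ValueError` containing `"must be"`" pattern.

```python
from dataclasses import dataclass


@dataclass
class ToyAPConfig:
    """Hypothetical two-field stand-in for a validated config dataclass."""

    min_valid_frac: float = 0.6  # assumed default, for illustration only
    window_len: int = 50  # assumed default, for illustration only

    def __post_init__(self) -> None:
        # Fractions must lie in the closed interval [0, 1].
        if not 0.0 <= self.min_valid_frac <= 1.0:
            raise ValueError("min_valid_frac must be in [0, 1]")
        # Counts must be true integers (bool excluded) and at least 1.
        is_int = isinstance(self.window_len, int) and not isinstance(
            self.window_len, bool
        )
        if not is_int or self.window_len < 1:
            raise ValueError("window_len must be a positive integer")


def raises_must_be(**kwargs) -> bool:
    """Return True if construction raises ValueError mentioning 'must be'."""
    try:
        ToyAPConfig(**kwargs)
    except ValueError as err:
        return "must be" in str(err)
    return False


invalid_cases = [
    {"min_valid_frac": -0.1},  # negative fraction
    {"min_valid_frac": 1.5},  # above 1.0 for a [0, 1] field
    {"window_len": 0},  # zero for a count field
    {"window_len": 2.5},  # float where an integer is required
]
print([raises_must_be(**case) for case in invalid_cases])
```

In a pytest suite the same cases would typically be fed through `pytest.mark.parametrize` and checked with `pytest.raises(ValueError, match="must be")`.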
2. Empirical Validation (SANITY CHECK)
The 3-step filter-cascade thresholds and pair-scoring method were empirically optimized via two validation studies on 5 diverse multi-animal datasets (2Flies, 2Mice, 4Gerbils, 5Mice, 2Bees) with hand-curated ground-truth AP node rankings.
Deprecated exhaustive grid search.
Analysis 2a: Grid Search over Design and Parameter Space
Example Script | Detailed Log | Results JSON
Method: Exhaustive grid search over 616,896 configurations testing several method categories:
Midpoint:
Lateral threshold:
Edge threshold:
Normalization:
Formula:
Pair scoring:
Weights:
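A grid of this kind is the Cartesian product of the option lists in each category. The sketch below shows one way to enumerate it and collect every configuration tied for the top score. The option values, the `toy_score` function, and the helper names are all hypothetical; the PR's actual categories, values, and scoring are not reproduced here.

```python
import itertools
from typing import Callable, Iterator

# Hypothetical option lists; the real grid covers more categories.
GRID = {
    "lateral_thresh_pct": [40, 50, 60],
    "edge_thresh_pct": [60, 70, 80],
    "normalization": ["minmax", "zscore"],
}


def iter_configs(grid: dict[str, list]) -> Iterator[dict]:
    """Yield one dict per point of the Cartesian product of all options."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def best_configs(grid: dict[str, list], score: Callable[[dict], float]) -> list[dict]:
    """Return every configuration that ties for the highest score."""
    scored = [(score(cfg), cfg) for cfg in iter_configs(grid)]
    top = max(s for s, _ in scored)
    return [cfg for s, cfg in scored if s == top]


def toy_score(cfg: dict) -> int:
    """Made-up score standing in for 'datasets solved out of 5'."""
    return 5 if cfg["lateral_thresh_pct"] == 50 else 3


n_configs = sum(1 for _ in iter_configs(GRID))
winners = best_configs(GRID, toy_score)
print(n_configs, len(winners))
```

Returning all tied winners, rather than a single arg-max, makes it visible when many configurations are indistinguishable under the chosen score.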
For each configuration:
the best individual per dataset was selected via max R×M
then the 3-step filter cascade was applied to identify the suggested AP pair
results were scored by:
Results + Implementation
Multiple configurations achieved 5/5 datasets with both nodes in GT and correct ordering.
Tied (with many others) for the top-ranked configuration:
Configuration: (`ValidateAPConfig`)
All configurable thresholds are collected in a single dataclass in `movement.kinematics.body_axis`. Users pass overrides as a dict via `ap_validation_config`; any omitted key uses its default. The config below represents the state the pipeline was in during same-axis GT node-pair accuracy evaluation.
- `min_valid_frac`
- `window_len`
- `stride`
- `pct_thresh`
- `min_run_len`
- `postural_var_ratio_thresh`
- `max_clusters` (… `min(max_clusters, n//2)`).
- `confidence_floor`
- `lateral_thresh_pct`
- `edge_thresh_pct`
- `lateral_var_weight`
- `longitudinal_var_weight`

Per-Dataset Filter Cascade Results (5 datasets, self-axis)
Per-dataset filter-cascade results with the optimal (top-ranked) configuration applied.
Top row: shows average skeletons (best individual by R×M within each dataset) with PC1 axes, GT nodes marked, and suggested pairs labeled
Bottom-left: shows GT-node coverage per dataset
Bottom-right: shows the filter-cascade progression:
Steps 1-2 filter pairs
Step 3 classifies survivors as:
the ★ marker indicates which segment the suggested pair was selected from
all 5 datasets show correct AP pair identification (SANITY CHECK)
Example Script | Detailed Log
Analysis 2b: Metric Evaluation for “Best” Individual Selection
Example Script | Detailed Log
Method:
For each of 5 metrics:
Metrics tested:
Resultant Length × Vote Margin: composite locomotion-quality score (aka R×M)
PC1 variance ratio: ratio of PC1 to PC2 singular values from SVD on the average skeleton
Mean inverse lateral variance: average of 1/σ for each keypoint's lateral offset over time
Agreement score: fraction of other individuals whose GT-node ordering (projected onto their own AP axis) matches this individual's ordering
Skeleton completeness: fraction of keypoints valid (non-NaN) in the average skeleton
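Several of these metrics follow directly from their one-line definitions. The sketch below is an illustration under stated assumptions, not the PR's code: the function names and array layouts are hypothetical, the standard deviation is the population form, and `resultant_length` is the standard circular-statistics mean resultant length (the vote-margin factor of R×M is not defined in this section, so it is left out).

```python
import numpy as np


def pc1_variance_ratio(avg_skeleton: np.ndarray) -> float:
    """Ratio of the first to the second singular value of the centered skeleton.

    avg_skeleton has shape (n_keypoints, 2); NaN rows are dropped.
    Assumes at least two valid keypoints remain.
    """
    pts = avg_skeleton[~np.isnan(avg_skeleton).any(axis=1)]
    centered = pts - pts.mean(axis=0)
    s = np.linalg.svd(centered, compute_uv=False)
    return float(s[0] / s[1]) if s[1] > 0 else float("inf")


def mean_inv_lateral_std(lateral_offsets: np.ndarray) -> float:
    """Average of 1/sigma over keypoints; input shape (n_frames, n_keypoints).

    Assumes every keypoint has non-zero lateral spread.
    """
    sigma = np.nanstd(lateral_offsets, axis=0)
    return float(np.mean(1.0 / sigma))


def skeleton_completeness(avg_skeleton: np.ndarray) -> float:
    """Fraction of keypoints whose coordinates are all non-NaN."""
    return float((~np.isnan(avg_skeleton).any(axis=1)).mean())


def resultant_length(angles: np.ndarray) -> float:
    """Mean resultant length of a set of directions given in radians."""
    return float(np.hypot(np.mean(np.cos(angles)), np.mean(np.sin(angles))))


# Toy skeleton elongated along x, with one missing keypoint.
skeleton = np.array(
    [[-2.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, -1.0], [np.nan, np.nan]]
)
offsets = np.array([[-1.0, -2.0], [1.0, 2.0]])  # two frames, two keypoints

print(pc1_variance_ratio(skeleton))
print(mean_inv_lateral_std(offsets))
print(skeleton_completeness(skeleton))
print(resultant_length(np.zeros(3)))
```

In the toy skeleton the x-extent is twice the y-extent, so the singular-value ratio is 2; a higher ratio indicates a more elongated average skeleton, which is not by itself evidence that the inferred anterior direction is correct.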
Results:
- … `mean_inv_lateral_var` both achieve perfect reference selection (5/5)
Detailed Per-Dataset Breakdown (Reference Selection)
The 2Bees case is particularly instructive:
- `track_0` has a higher PC1 variance ratio, higher skeleton completeness, and an equal agreement score, but 0% GT accuracy
- … `mean_inv_lateral_var`) correctly identify `track_1` as the trustworthy reference
R×M correctly selects `track_1` (100% GT accuracy) over `track_0` (0% GT accuracy), despite `track_0` having a higher PC1 variance ratio and higher skeleton completeness.
Other Datasets
2Flies (`track_0`):
2Mice (`track_0`):
4Gerbils (`pup_unshaved`):
5Mice (`track_0`):
Flowchart: Input AP Node-Pair Filter Cascade
Survivors:
Distal pair: … `edge_thresh_pct` percentile
Proximal pair: … `edge_thresh_pct` percentile
Max-sep overall:
Max-sep distal:
Input pair rank:
AP Node-Pair Filter Cascade Flowchart
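The three cascade steps can be sketched on per-node AP and lateral coordinates. This is an illustration under explicit assumptions, not the PR's implementation: lateral offsets are min-max normalized on their absolute values, the midpoint is taken as the centre of the AP range, "opposite sides" means opposite signs relative to that midpoint, and a pair is called distal when both nodes sit at or beyond the `edge_thresh_pct` percentile of distance from the midpoint. The function name `filter_cascade` is hypothetical.

```python
import itertools

import numpy as np


def filter_cascade(ap, lateral, lateral_thresh_pct=50.0, edge_thresh_pct=70.0):
    """Return (from_idx, to_idx, label) for every pair surviving Steps I-II."""
    ap = np.asarray(ap, dtype=float)
    lateral = np.abs(np.asarray(lateral, dtype=float))

    # Step I: drop nodes whose normalized lateral offset exceeds the percentile.
    span = lateral.max() - lateral.min()
    if span == 0:
        norm = np.zeros_like(lateral)  # degenerate case: every node passes
    else:
        norm = (lateral - lateral.min()) / span
    keep = np.flatnonzero(norm <= np.percentile(norm, lateral_thresh_pct))

    # Step II: keep pairs whose nodes lie on opposite sides of the midpoint
    # (assumed here to be the centre of the AP range).
    mid = (ap.max() + ap.min()) / 2.0
    pairs = [
        (int(i), int(j))
        for i, j in itertools.combinations(keep, 2)
        if (ap[i] - mid) * (ap[j] - mid) < 0
    ]

    # Step III: label each pair distal or proximal by distance from midpoint.
    dist = np.abs(ap - mid)
    edge = np.percentile(dist, edge_thresh_pct)
    out = []
    for i, j in pairs:
        frm, to = (i, j) if ap[i] < ap[j] else (j, i)  # posterior node first
        both_far = dist[frm] >= edge and dist[to] >= edge
        out.append((frm, to, "distal" if both_far else "proximal"))
    return out


# Toy skeleton: node 0 = tail end, node 4 = head end, nodes 1 and 3 are lateral.
ap_coords = [-2.0, -1.0, 0.5, 1.0, 2.0]
lateral_coords = [0.0, 0.75, 0.25, 1.0, 0.125]
print(filter_cascade(ap_coords, lateral_coords))
```

In this toy case the two strongly lateral nodes are removed in Step I, and the remaining end-to-end pair (0, 4) is the only one classified as distal.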
References
Is this a breaking change?
No.
Does this PR require an update to the documentation?
No - API docs auto-generate from docstrings.
Checklist
Future Refactoring
The `body_axis.py` module (~2,900 lines) is intentionally monolithic in this PR to simplify review and iteration. Once the API stabilizes, general-purpose functionality could be extracted into existing or new utility modules.
For example: